{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "-P3OUvJirQdR"
   },
   "outputs": [],
   "source": [
    "# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "44lB2sH-rQdW"
   },
   "source": [
    "# Camera position optimization using differentiable rendering\n",
    "\n",
    "In this tutorial we will learn the [x, y, z] position of a camera given a reference image using differentiable rendering.\n",
    "\n",
    "We first initialize a renderer with a starting position for the camera. We then use it to generate an image, compute a loss against the reference image, and backpropagate through the entire pipeline to update the position of the camera.\n",
    "\n",
    "This tutorial shows how to:\n",
    "- load a mesh from an `.obj` file\n",
    "- initialize a `Camera`, `Shader` and `Renderer`\n",
    "- render a mesh\n",
    "- set up an optimization loop with a loss function and optimizer\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "AZGmIlmWrQdX"
   },
   "source": [
    "## 0. Install and import modules"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "qkX7DiM6rmeM"
   },
   "source": [
    "Ensure `torch` and `torchvision` are installed. If `pytorch3d` is not installed, install it using the following cell:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 717
    },
    "colab_type": "code",
    "id": "sEVdNGFwripM",
    "outputId": "27047061-a29b-4562-c164-c1288e24c266"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import torch\n",
    "need_pytorch3d=False\n",
    "try:\n",
    "    import pytorch3d\n",
    "except ModuleNotFoundError:\n",
    "    need_pytorch3d=True\n",
    "if need_pytorch3d:\n",
    "    if torch.__version__.startswith(\"1.9\") and sys.platform.startswith(\"linux\"):\n",
    "        # We try to install PyTorch3D via a released wheel.\n",
    "        version_str=\"\".join([\n",
    "            f\"py3{sys.version_info.minor}_cu\",\n",
    "            torch.version.cuda.replace(\".\",\"\"),\n",
    "            f\"_pyt{torch.__version__[0:5:2]}\"\n",
    "        ])\n",
    "        !pip install pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
    "    else:\n",
    "        # We try to install PyTorch3D from source.\n",
    "        !curl -LO https://github.com/NVIDIA/cub/archive/1.10.0.tar.gz\n",
    "        !tar xzf 1.10.0.tar.gz\n",
    "        os.environ[\"CUB_HOME\"] = os.getcwd() + \"/cub-1.10.0\"\n",
    "        !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "w9mH5iVprQdZ"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import numpy as np\n",
    "from tqdm.notebook import tqdm\n",
    "import imageio\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import matplotlib.pyplot as plt\n",
    "from skimage import img_as_ubyte\n",
    "\n",
    "# io utils\n",
    "from pytorch3d.io import load_obj\n",
    "\n",
    "# datastructures\n",
    "from pytorch3d.structures import Meshes\n",
    "\n",
    "# 3D transformations functions\n",
    "from pytorch3d.transforms import Rotate, Translate\n",
    "\n",
    "# rendering components\n",
    "from pytorch3d.renderer import (\n",
    "    FoVPerspectiveCameras, look_at_view_transform, look_at_rotation, \n",
    "    RasterizationSettings, MeshRenderer, MeshRasterizer, BlendParams,\n",
    "    SoftSilhouetteShader, HardPhongShader, PointLights, TexturesVertex,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "cpUf2UvirQdc"
   },
   "source": [
    "## 1. Load the Obj\n",
    "\n",
    "We will load an `.obj` file and create a **Meshes** object. **Meshes** is a datastructure provided in PyTorch3D for working with **batches of meshes of different sizes**. It has several useful methods which are used in the rendering pipeline. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "8d-oREfkrt_Z"
   },
   "source": [
    "If you are running this notebook locally after cloning the PyTorch3D repository, the mesh will already be available. **If using Google Colab, fetch the mesh and save it at the path `data/`**:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 204
    },
    "colab_type": "code",
    "id": "sD5KcLuJr0PL",
    "outputId": "e65061fa-dbd5-4c06-b559-3592632983ee"
   },
   "outputs": [],
   "source": [
    "!mkdir -p data\n",
    "!wget -P data https://dl.fbaipublicfiles.com/pytorch3d/data/teapot/teapot.obj"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "VWiPKnEIrQdd"
   },
   "outputs": [],
   "source": [
    "# Set the cuda device \n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda:0\")\n",
    "    torch.cuda.set_device(device)\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "\n",
    "# Load the obj and ignore the textures and materials.\n",
    "verts, faces_idx, _ = load_obj(\"./data/teapot.obj\")\n",
    "faces = faces_idx.verts_idx\n",
    "\n",
    "# Initialize each vertex to be white in color.\n",
    "verts_rgb = torch.ones_like(verts)[None]  # (1, V, 3)\n",
    "textures = TexturesVertex(verts_features=verts_rgb.to(device))\n",
    "\n",
    "# Create a Meshes object for the teapot. Here we have only one mesh in the batch.\n",
    "teapot_mesh = Meshes(\n",
    "    verts=[verts.to(device)],   \n",
    "    faces=[faces.to(device)], \n",
    "    textures=textures\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "mgtGbQktrQdh"
   },
   "source": [
    "## 2. Optimization setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Q6PzKD_NrQdi"
   },
   "source": [
    "### Create a renderer\n",
    "\n",
    "A **renderer** in PyTorch3D is composed of a **rasterizer** and a **shader**, each of which has a number of subcomponents such as a **camera** (orthographic/perspective). Here we initialize some of these components and use default values for the rest. \n",
    "\n",
    "To optimize the camera position we will use a renderer which produces only a **silhouette** of the object and does not apply any **lighting** or **shading**. We will also initialize another renderer which applies full **Phong shading** and use it to visualize the outputs. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "KPlby75GrQdj"
   },
   "outputs": [],
   "source": [
    "# Initialize a perspective camera.\n",
    "cameras = FoVPerspectiveCameras(device=device)\n",
    "\n",
    "# To blend the 100 faces we set a few parameters which control the opacity and the sharpness of \n",
    "# edges. Refer to blending.py for more details. \n",
    "blend_params = BlendParams(sigma=1e-4, gamma=1e-4)\n",
    "\n",
    "# Define the settings for rasterization and shading. Here we set the output image to be of size\n",
    "# 256x256. To form the blended image we use 100 faces for each pixel. We leave bin_size and\n",
    "# max_faces_per_bin at their default of None, so that the faster coarse-to-fine rasterization\n",
    "# method is used. Refer to rasterize_meshes.py for explanations of these parameters. Refer to\n",
    "# docs/notes/renderer.md for an explanation of the difference between naive and coarse-to-fine\n",
    "# rasterization.\n",
    "raster_settings = RasterizationSettings(\n",
    "    image_size=256, \n",
    "    blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma, \n",
    "    faces_per_pixel=100, \n",
    ")\n",
    "\n",
    "# Create a silhouette mesh renderer by composing a rasterizer and a shader. \n",
    "silhouette_renderer = MeshRenderer(\n",
    "    rasterizer=MeshRasterizer(\n",
    "        cameras=cameras, \n",
    "        raster_settings=raster_settings\n",
    "    ),\n",
    "    shader=SoftSilhouetteShader(blend_params=blend_params)\n",
    ")\n",
    "\n",
    "\n",
    "# We will also create a Phong renderer. This is simpler and only needs to render one face per pixel.\n",
    "raster_settings = RasterizationSettings(\n",
    "    image_size=256, \n",
    "    blur_radius=0.0, \n",
    "    faces_per_pixel=1, \n",
    ")\n",
    "# We can add a point light in front of the object. \n",
    "lights = PointLights(device=device, location=((2.0, 2.0, -2.0),))\n",
    "phong_renderer = MeshRenderer(\n",
    "    rasterizer=MeshRasterizer(\n",
    "        cameras=cameras, \n",
    "        raster_settings=raster_settings\n",
    "    ),\n",
    "    shader=HardPhongShader(device=device, cameras=cameras, lights=lights)\n",
    ")"
   ]
  },
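  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `blur_radius` expression used for the silhouette renderer can be checked by hand. The sketch below is a NumPy-only illustration, not PyTorch3D code: it assumes that the soft silhouette shader turns a face's distance value `d` into a per-face probability `sigmoid(-d / sigma)`, and the helper name `face_probability` is hypothetical. Under that assumption, a distance value equal to `blur_radius` corresponds to a probability of `1e-4`, the cutoff that appears inside the logarithm."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "sigma = 1e-4   # same value as BlendParams(sigma=1e-4) above\n",
    "cutoff = 1e-4  # the constant inside the logarithm of the blur_radius expression\n",
    "\n",
    "def face_probability(dist, sigma):\n",
    "    # Hypothetical helper: sigmoid(-dist / sigma), an assumed form of the soft blending.\n",
    "    return 1.0 / (1.0 + np.exp(dist / sigma))\n",
    "\n",
    "# Same expression as the blur_radius passed to RasterizationSettings above.\n",
    "blur_radius = np.log(1.0 / cutoff - 1.0) * sigma\n",
    "prob_at_blur_radius = face_probability(blur_radius, sigma)\n",
    "\n",
    "print(blur_radius, prob_at_blur_radius)"
   ]
  },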
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "osOy2OIJrQdn"
   },
   "source": [
    "### Create a reference image\n",
    "\n",
    "We will first choose a viewpoint and generate an image. We use a helper function, `look_at_view_transform`, to compute the camera rotation and translation for that viewpoint. Then we can use the renderers to produce an image. Here we will use both renderers and visualize the silhouette and the fully shaded image. \n",
    "\n",
    "The world coordinate system is defined as +Y up, +X left and +Z in. The teapot in world coordinates has the spout pointing to the left. \n",
    "\n",
    "We define a camera positioned on the positive z axis, so it sees the spout pointing to the right. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 305
    },
    "colab_type": "code",
    "id": "EjJrW7qerQdo",
    "outputId": "93545b65-269e-4719-f4a2-52cbc6c9c974"
   },
   "outputs": [],
   "source": [
    "# Select the viewpoint using spherical angles  \n",
    "distance = 3   # distance from camera to the object\n",
    "elevation = 50.0   # angle of elevation in degrees\n",
    "azimuth = 0.0  # No rotation so the camera is positioned on the +Z axis. \n",
    "\n",
    "# Get the position of the camera based on the spherical angles\n",
    "R, T = look_at_view_transform(distance, elevation, azimuth, device=device)\n",
    "\n",
    "# Render the teapot providing the values of R and T. \n",
    "silhouette = silhouette_renderer(meshes_world=teapot_mesh, R=R, T=T)\n",
    "image_ref = phong_renderer(meshes_world=teapot_mesh, R=R, T=T)\n",
    "\n",
    "silhouette = silhouette.cpu().numpy()\n",
    "image_ref = image_ref.cpu().numpy()\n",
    "\n",
    "plt.figure(figsize=(10, 10))\n",
    "plt.subplot(1, 2, 1)\n",
    "plt.imshow(silhouette.squeeze()[..., 3])  # only plot the alpha channel of the RGBA image\n",
    "plt.grid(False)\n",
    "plt.subplot(1, 2, 2)\n",
    "plt.imshow(image_ref.squeeze())\n",
    "plt.grid(False)"
   ]
  },
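  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The viewpoint above is given as spherical angles. As a rough NumPy-only sketch of how `distance`, `elevation` and `azimuth` could map to a camera position, assume +Y is up and the azimuth is measured from the +Z axis. The helper `camera_position_from_angles` is hypothetical and is not a PyTorch3D function. Under these assumptions, `elevation = 50` and `azimuth = 0` place the camera in the YZ plane, 3 units from the origin."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def camera_position_from_angles(dist, elev_deg, azim_deg):\n",
    "    # Hypothetical helper, assuming +Y up and azimuth measured from the +Z axis.\n",
    "    elev = np.deg2rad(elev_deg)\n",
    "    azim = np.deg2rad(azim_deg)\n",
    "    x = dist * np.cos(elev) * np.sin(azim)\n",
    "    y = dist * np.sin(elev)\n",
    "    z = dist * np.cos(elev) * np.cos(azim)\n",
    "    return np.array([x, y, z])\n",
    "\n",
    "# Same distance, elevation and azimuth values as in the cell above.\n",
    "cam_np = camera_position_from_angles(3.0, 50.0, 0.0)\n",
    "print(cam_np)                  # camera position as [x, y, z]\n",
    "print(np.linalg.norm(cam_np))  # distance from the camera to the origin"
   ]
  },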
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "plBJwEslrQdt"
   },
   "source": [
    "### Set up a basic model \n",
    "\n",
    "Here we create a simple model class and initialize a parameter for the camera position. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "YBbP1-EDrQdu"
   },
   "outputs": [],
   "source": [
    "class Model(nn.Module):\n",
    "    def __init__(self, meshes, renderer, image_ref):\n",
    "        super().__init__()\n",
    "        self.meshes = meshes\n",
    "        self.device = meshes.device\n",
    "        self.renderer = renderer\n",
    "        \n",
    "        # Get the silhouette of the reference RGB image by finding all non-white pixel values. \n",
    "        image_ref = torch.from_numpy((image_ref[..., :3].max(-1) != 1).astype(np.float32))\n",
    "        self.register_buffer('image_ref', image_ref)\n",
    "        \n",
    "        # Create an optimizable parameter for the x, y, z position of the camera. \n",
    "        self.camera_position = nn.Parameter(\n",
    "            torch.from_numpy(np.array([3.0,  6.9, +2.5], dtype=np.float32)).to(meshes.device))\n",
    "\n",
    "    def forward(self):\n",
    "        \n",
    "        # Render the image using the updated camera position. Based on the new position of the \n",
    "        # camera we calculate the rotation and translation matrices\n",
    "        R = look_at_rotation(self.camera_position[None, :], device=self.device)  # (1, 3, 3)\n",
    "        T = -torch.bmm(R.transpose(1, 2), self.camera_position[None, :, None])[:, :, 0]   # (1, 3)\n",
    "        \n",
    "        image = self.renderer(meshes_world=self.meshes.clone(), R=R, T=T)\n",
    "        \n",
    "        # Calculate the silhouette loss\n",
    "        loss = torch.sum((image[..., 3] - self.image_ref) ** 2)\n",
    "        return loss, image\n",
    "  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "qCGLSJtfrQdy"
   },
   "source": [
    "## 3. Initialize the model and optimizer\n",
    "\n",
    "Now we can create an instance of the **model** above and set up an **optimizer** for the camera position parameter. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "srZPBU7_rQdz"
   },
   "outputs": [],
   "source": [
    "# We will save images periodically and compose them into a GIF.\n",
    "filename_output = \"./teapot_optimization_demo.gif\"\n",
    "writer = imageio.get_writer(filename_output, mode='I', duration=0.3)\n",
    "\n",
    "# Initialize a model using the renderer, mesh and reference image\n",
    "model = Model(meshes=teapot_mesh, renderer=silhouette_renderer, image_ref=image_ref).to(device)\n",
    "\n",
    "# Create an optimizer. Here we are using Adam and we pass in the parameters of the model\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.05)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "dvTLnrWorQd2"
   },
   "source": [
    "### Visualize the starting position and the reference position"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 335
    },
    "colab_type": "code",
    "id": "qyRXpP3mrQd3",
    "outputId": "47ecb12a-e68c-47f5-92fc-821a7a9bd661"
   },
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10, 10))\n",
    "\n",
    "_, image_init = model()\n",
    "plt.subplot(1, 2, 1)\n",
    "plt.imshow(image_init.detach().squeeze().cpu().numpy()[..., 3])\n",
    "plt.grid(False)\n",
    "plt.title(\"Starting position\")\n",
    "\n",
    "plt.subplot(1, 2, 2)\n",
    "plt.imshow(model.image_ref.cpu().numpy().squeeze())\n",
    "plt.grid(False)\n",
    "plt.title(\"Reference silhouette\");\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "aGJu7h-lrQd5"
   },
   "source": [
    "## 4. Run the optimization \n",
    "\n",
    "We run up to 200 iterations of the forward and backward pass, stopping early once the loss drops below 200, and save outputs every 10 iterations. When this has finished, take a look at `./teapot_optimization_demo.gif` for a GIF of the optimization process!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000,
     "referenced_widgets": [
      "79d7fc84b5564206ab64b2759474da04",
      "02acadb61c3949fcaeab177fd184c388",
      "efd9860908c64bfe9d47118be4734648",
      "f8df7c6efb7d47f5be760a39b4bdbcf8",
      "d8a109658c364a00ab4d298112dac6db",
      "2d05db82cc99482bb3d62b6d4e5b1a98",
      "c621d425e2c8426c8cd4f9136d392af1",
      "3df8063f307040ebb8ff8e2f26ccf729"
     ]
    },
    "colab_type": "code",
    "id": "HvnK5VI5rQd6",
    "outputId": "4019c697-3fc6-4c7b-cdfe-225633cc0d60"
   },
   "outputs": [],
   "source": [
    "loop = tqdm(range(200))\n",
    "for i in loop:\n",
    "    optimizer.zero_grad()\n",
    "    loss, _ = model()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    \n",
    "    loop.set_description('Optimizing (loss %.4f)' % loss.item())\n",
    "    \n",
    "    if loss.item() < 200:\n",
    "        break\n",
    "    \n",
    "    # Save outputs to create a GIF. \n",
    "    if i % 10 == 0:\n",
    "        R = look_at_rotation(model.camera_position[None, :], device=model.device)\n",
    "        T = -torch.bmm(R.transpose(1, 2), model.camera_position[None, :, None])[:, :, 0]   # (1, 3)\n",
    "        image = phong_renderer(meshes_world=model.meshes.clone(), R=R, T=T)\n",
    "        image = image[0, ..., :3].detach().squeeze().cpu().numpy()\n",
    "        image = img_as_ubyte(image)\n",
    "        writer.append_data(image)\n",
    "        \n",
    "        plt.figure()\n",
    "        plt.imshow(image[..., :3])\n",
    "        plt.title(\"iter: %d, loss: %0.2f\" % (i, loss.item()))\n",
    "        plt.axis(\"off\")\n",
    "    \n",
    "writer.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "mWj80P_SsPTN"
   },
   "source": [
    "## 5. Conclusion \n",
    "\n",
    "In this tutorial we learnt how to **load** a mesh from an obj file, initialize a PyTorch3D datastructure called **Meshes**, set up a **Renderer** consisting of a **Rasterizer** and a **Shader**, set up an optimization loop including a **Model** and a **loss function**, and run the optimization. "
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "anp_metadata": {
   "path": "fbsource/fbcode/vision/fair/pytorch3d/docs/tutorials/camera_position_optimization_with_differentiable_rendering.ipynb"
  },
  "bento_stylesheets": {
   "bento/extensions/flow/main.css": true,
   "bento/extensions/kernel_selector/main.css": true,
   "bento/extensions/kernel_ui/main.css": true,
   "bento/extensions/new_kernel/main.css": true,
   "bento/extensions/system_usage/main.css": true,
   "bento/extensions/theme/main.css": true
  },
  "colab": {
   "name": "camera_position_optimization_with_differentiable_rendering.ipynb",
   "provenance": [],
   "toc_visible": true
  },
  "disseminate_notebook_info": {
   "backup_notebook_id": "1062179640844868"
  },
  "kernelspec": {
   "display_name": "pytorch3d (local)",
   "language": "python",
   "name": "pytorch3d_local"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5+"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
